{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 代码实现"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "我们首先编写DeepPose的代码。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torchvision.models as models\n",
    "\n",
    "class DeepPose(nn.Module):\n",
    "    def __init__(self, num_keypoints=17, pretrained=True, num_stages=3):\n",
    "        super(DeepPose, self).__init__()\n",
    "        self.num_keypoints = num_keypoints\n",
    "        self.num_stages = num_stages\n",
    "\n",
    "        # 加载预训练的ResNet模型用于特征提取\n",
    "        self.backbone = models.resnet50(pretrained=pretrained)\n",
    "        self.backbone = nn.Sequential(*list(self.backbone.children())[:-2])\n",
    "\n",
    "        # 为每个级联阶段定义回归层\n",
    "        self.regression_layers = nn.ModuleList([\n",
    "            self._make_regression_layer() for _ in range(num_stages)\n",
    "        ])\n",
    "\n",
    "        # 每个级联阶段的最终全连接层，用于关节点预测\n",
    "        self.fc_layers = nn.ModuleList([\n",
    "            nn.Linear(2048, num_keypoints * 2) for _ in range(num_stages)\n",
    "        ])\n",
    "\n",
    "    def _make_regression_layer(self):\n",
    "        # 定义回归层，由几个卷积层组成\n",
    "        return nn.Sequential(\n",
    "            nn.Conv2d(2048, 512, kernel_size=3, padding=1),\n",
    "            nn.BatchNorm2d(512),\n",
    "            nn.ReLU(inplace=True),\n",
    "            nn.Conv2d(512, 512, kernel_size=3, padding=1),\n",
    "            nn.BatchNorm2d(512),\n",
    "            nn.ReLU(inplace=True),\n",
    "            nn.Conv2d(512, 2048, kernel_size=1)\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        # 使用骨干网络提取特征\n",
    "        features = self.backbone(x)\n",
    "\n",
    "        # 初始化关节点预测为零\n",
    "        keypoint_preds = torch.zeros(x.size(0), self.num_keypoints * 2).to(x.device)\n",
    "\n",
    "        for i in range(self.num_stages):\n",
    "            # 将关节点预测结果与特征图拼接\n",
    "            keypoint_map = keypoint_preds.view(x.size(0), self.num_keypoints, 2, 1, 1)\n",
    "            keypoint_map = keypoint_map.expand(-1, -1, -1, features.size(2), features.size(3))\n",
    "            features_with_keypoints = torch.cat([features, keypoint_map.view(x.size(0), -1, features.size(2), features.size(3))], dim=1)\n",
    "\n",
    "            # 通过回归层进行细化\n",
    "            regression_output = self.regression_layers[i](features_with_keypoints)\n",
    "\n",
    "            # 将回归输出展平并通过全连接层\n",
    "            regression_output = regression_output.view(x.size(0), -1)\n",
    "            keypoint_preds += self.fc_layers[i](regression_output)\n",
    "\n",
    "        # 返回关节点的最终位置\n",
    "        return keypoint_preds.view(x.size(0), self.num_keypoints, 2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "接着，我们在COCO上对其进行训练。我们先导入必要的仓库以及库函数。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!git clone thttps://github.com/Naman-ntc/Pytorch-Human-Pose-Estimation.git\n",
    "pip install -r requirements.txt"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "设置好模型设置和COCO数据集的路径，开始对模型进行训练。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!python main.py -DataConfig conf/datasets/coco.defconf -ModelConfig conf/models/DeepPose.defconf "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<div style=\"display: inline-block; margin-top: 10px;\">\n",
    "        <img style=\"border-radius: 0.3125em;\" \n",
    "        src=\"https://pic3.zhimg.com/80/v2-4c76b584c3e5682a2153425c132f6b2e_1440w.webp\" width=1000>\n",
    "        <div style=\"color:orange; \n",
    "        display: block;\n",
    "        color: #999;\n",
    "        padding: 2px;\"></div>\n",
    "    </div>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "由于训练输出较长，这里我们只展示开始训练的阶段。最后，我们在测试集上对模型进行测试，可发现最终的PCK可以达到57.5。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<div style=\"display: inline-block; margin-top: 10px;\">\n",
    "        <img style=\"border-radius: 0.3125em;\" \n",
    "        src=\"https://pic3.zhimg.com/80/v2-4b53bd3a8785cba4762d37f18fee0f32_1440w.jpg\" width=1000>\n",
    "        <div style=\"color:orange; \n",
    "        display: block;\n",
    "        color: #999;\n",
    "        padding: 2px;\"></div>\n",
    "    </div>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<!-- 目前的困难：\n",
    "1. 姿态估计的精确度问题：由于人体姿态估计系统使用的是多种视觉信息，如果它们不能够准确地追踪到人体的每个部位，那么估计出来的结果也会不准确。\n",
    "1. 尺度不变性问题：尽管现有的人体姿态估计算法可以处理不同尺度的图像，但它们仍然存在尺度不变性问题，即当姿态发生变化时，估计的结果可能会受到影响。\n",
    "1. 光照变化问题：由于光照变化会对视觉信息产生影响，因此人体姿态估计系统可能无法准确地识别不同光照环境下的人体部位。\n",
    "1. 复杂背景问题：复杂的背景可能会干扰人体姿态估计系统的性能，使得它无法准确地识别人体部位。 -->"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
